# -*- coding: utf-8 -*-
import pickle
import numpy as np
import pandas as pd
from sklearn import cluster

def clustering(X, eps, min_samples, coef, cluster_targets):
    """Cluster a re-weighted copy of ``X`` with DBSCAN and return the labels.

    Every column named in ``cluster_targets`` is multiplied by ``coef`` on a
    private copy of ``X``, so the caller's object is never modified.  The
    copy is then fitted with ``cluster.DBSCAN(eps=eps,
    min_samples=min_samples, n_jobs=-1)``.

    Returns:
        The fitted model's ``labels_`` attribute (one label per row of ``X``).
    """
    weighted = X.copy()
    # Emphasise the clustering targets by scaling them with the weight.
    for column in cluster_targets:
        weighted[column] = coef * weighted[column]

    model = cluster.DBSCAN(eps=eps, min_samples=min_samples, n_jobs=-1)
    return model.fit(weighted).labels_
    
    
def cluster_param_select(data, scaler, epsilon_list, min_samples_list,
                         coef_list, cluster_targets=('Plx', 'B-V')):
    """Grid-search the DBSCAN parameters and return the best combination.

    Every ``(eps, min_samples, coef)`` combination is clustered with
    ``clustering``.  Only the labels of the rows whose un-scaled ``Plx``
    value lies in [20, 22] are scored.  A combination is discarded when:

    * those rows carry a single distinct label (or there are no such rows);
    * the most frequent label among those rows is -1 (noise points).

    The remaining combinations are ranked by the number of clusters (noise
    excluded) and then by the size of the largest cluster, both descending.
    Ties are resolved in favour of the combination that was tried first.

    Args:
        data: Feature table passed to ``clustering``; it must expose
            ``.columns`` and contain a ``Plx`` column (presumably a scaled
            pandas DataFrame -- confirm against the caller).
        scaler: Object whose ``inverse_transform(data)`` yields the un-scaled
            values, in the same column order as ``data``.
        epsilon_list: Candidate ``eps`` values.
        min_samples_list: Candidate ``min_samples`` values.
        coef_list: Candidate weights applied to ``cluster_targets``.
        cluster_targets: Names of the columns that get weighted.

    Returns:
        Tuple ``(best_eps, best_min_samples, best_coef)``.

    Raises:
        ValueError: If no parameter combination passes the filters.
    """
    # Positions of the rows whose Plx (in original units) falls in [20, 22].
    # X_before gets a fresh 0..n-1 index, so its index values are positions.
    X_before = pd.DataFrame(data=scaler.inverse_transform(data),
                            columns=data.columns)
    in_range = (20 <= X_before['Plx']) & (X_before['Plx'] <= 22)
    Plx_20_22 = X_before.loc[in_range].index

    # (eps, min_samples, coef) -> (size of largest cluster, number of clusters)
    params_dict = {}
    for eps in epsilon_list:
        for min_samples in min_samples_list:
            for coef in coef_list:
                labels = clustering(data, eps, min_samples, coef,
                                    cluster_targets)
                labels = labels[Plx_20_22]

                clusters = np.unique(labels)
                num_clusters = len(clusters)
                if num_clusters == 1:
                    continue

                # Number of selected rows per label, in first-seen order.
                clu_ind_dict = {}
                for label in labels:
                    clu_ind_dict[label] = clu_ind_dict.get(label, 0) + 1
                if not clu_ind_dict:
                    continue

                # max() keeps the first label seen among equally large ones.
                main_label, max_num = max(clu_ind_dict.items(),
                                          key=lambda item: item[1])
                if main_label == -1:
                    # The most populated group must not be the noise label.
                    continue

                if -1 in clusters:
                    # Noise points do not count as a cluster.
                    num_clusters -= 1

                params_dict[(eps, min_samples, coef)] = (max_num, num_clusters)

    if not params_dict:
        raise ValueError(
            'no (eps, min_samples, coef) combination produced a usable '
            'clustering of the rows with Plx in [20, 22]')

    # Rank by cluster count first, then by size of the largest cluster.
    best_params, best_values = max(
        params_dict.items(),
        key=lambda item: (item[1][1], item[1][0]))
    best_eps, best_min_samples, best_coef = best_params
    best_max_num, best_num_cluster = best_values

    # Report the winner (the selected coefficient, not the last one tried).
    print(f'聚类最佳参数为： \n epsilon: {best_eps}： \n min_samples: {best_min_samples}' +
                f'\n 系数： {best_coef} \n 簇最大个体数 {best_max_num}' +
                f'\n 簇数： {best_num_cluster}')

    return best_eps, best_min_samples, best_coef


def bi_star_cluster(X, df, scaler, best_eps, best_min_samples, best_coef,
                    cluster_targets=('Plx', 'B-V')):
    """Cluster ``X`` with the chosen parameters and extract the main cluster.

    The rows whose un-scaled ``Plx`` value lies in [20, 22] are selected, and
    the most populated non-noise cluster among them is taken as the target
    cluster (the original docstring calls it the "毕星团" cluster).  The
    size of every non-noise cluster among those rows is printed.

    Args:
        X: Feature table passed to ``clustering``; it must expose
            ``.columns`` and contain a ``Plx`` column.
        df: Catalogue table with a ``HIP`` column whose row positions match
            those of ``X`` (assumed -- confirm against the caller).  It is
            not modified.
        scaler: Object whose ``inverse_transform(X)`` yields the un-scaled
            values, in the same column order as ``X``.
        best_eps: DBSCAN ``eps``.
        best_min_samples: DBSCAN ``min_samples``.
        best_coef: Weight applied to ``cluster_targets``.
        cluster_targets: Names of the columns that get weighted.

    Returns:
        Tuple ``(hip_values, max_clu, labels_raw)``: the ``HIP`` values of
        the selected rows in the main cluster, that cluster's label, and the
        labels of all rows of ``X``.

    Raises:
        ValueError: If every selected row is noise (label -1) or no row has
            ``Plx`` in [20, 22].
    """
    # Positions of the rows whose Plx (in original units) falls in [20, 22].
    # X_before gets a fresh 0..n-1 index, so its index values are positions.
    X_before = pd.DataFrame(data=scaler.inverse_transform(X),
                            columns=X.columns)
    in_range = (20 <= X_before['Plx']) & (X_before['Plx'] <= 22)
    Plx_20_22 = X_before.loc[in_range].index

    # Keep the labels of every row; they are returned unchanged.
    labels_raw = clustering(X, best_eps, best_min_samples, best_coef,
                            cluster_targets)
    labels = labels_raw[Plx_20_22]

    # Copy the slice so that adding a column neither triggers pandas'
    # SettingWithCopyWarning nor depends on view-versus-copy semantics.
    df = df.iloc[Plx_20_22].copy()
    df['label'] = labels

    # Size of each non-noise cluster, in first-seen order.
    clu_and_num = {}
    for label in labels:
        if label == -1:
            continue
        clu_and_num[label] = clu_and_num.get(label, 0) + 1

    for clu, num in clu_and_num.items():
        print(f'簇{clu}的数量： ', num, '个')

    if not clu_and_num:
        raise ValueError(
            'no non-noise cluster found among the rows with Plx in [20, 22]')

    # Most populated cluster; the first one seen wins a tie.
    max_clu = max(clu_and_num.items(), key=lambda item: item[1])[0]

    bi_star_clu = df.loc[df['label'] == max_clu, 'HIP']
    return bi_star_clu.values, max_clu, labels_raw

    
        
if __name__ == '__main__':
    # NOTE: pickle.load can execute arbitrary code -- only load files that
    # were produced by a trusted step of this pipeline.
    file_path = r'../附件/data_ml.pkl'
    with open(file_path, 'rb') as f:
        X = pickle.load(f)

    # Parameter grid to search.
    epsilon_list = [1, 1.5, 2, 2.5, 3, 3.5]
    # min_samples is a neighbour count.  Recent scikit-learn releases reject
    # non-integer values, and where they were accepted 2.5 / 3.5 acted as a
    # ">= count" threshold, i.e. exactly like 3 / 4.  Only integers are kept.
    min_samples_list = [2, 3, 4]
    coef_list = [1.5, 2, 2.5, 3.0, 3.5, 4]

    scaler_model_path = r'../附件/scaler_model.pkl'
    with open(scaler_model_path, 'rb') as f:
        scaler = pickle.load(f)

    data = X.copy()
    # Select the best parameters (the function also prints them).
    best_eps, best_min_samples, best_coef = cluster_param_select(
        data, scaler, epsilon_list, min_samples_list, coef_list)
    # Work on a fresh copy in case the search modified its input.
    data = X.copy()

    path = r'../附件/星表数据.txt'
    # Columns are separated by runs of spaces.  A regex separator needs the
    # python parser; naming it explicitly avoids pandas' fallback warning.
    df = pd.read_csv(path, delimiter=' +', engine='python')
    # Cluster with the best parameters to pick out the target cluster.
    bi_star_clu, bi_clu_label, cluster_res = bi_star_cluster(
        data, df, scaler, best_eps, best_min_samples, best_coef)
    # Show the result.
    print(bi_star_clu)
    # Persist the results; each file is closed as soon as it is written.
    with open(r'../附件/bi_star_cluster_HIP.pkl', 'wb') as f:
        pickle.dump(bi_star_clu, f)
    with open(r'../附件/bi_star_cluster_label.pkl', 'wb') as f:
        pickle.dump(bi_clu_label, f)
    with open(r'../附件/clustering_result.pkl', 'wb') as f:
        pickle.dump(cluster_res, f)
        
    
